{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Non-Rigid Registration: Demons</h1>\n",
    "\n",
    "This notebook illustrates the use of the Demons based non-rigid registration set of algorithms in SimpleITK. These include both the DemonsMetric which is part of the registration framework and Demons registration filters which are not.\n",
    "\n",
    "The data we work with is a 4D (3D+time) thoracic-abdominal CT, the Point-validated Pixel-based Breathing Thorax Model (POPI) model. This data consists of a set of temporal CT volumes, a set of masks segmenting each of the CTs to air/body/lung, and a set of corresponding points across the CT volumes. \n",
    "\n",
    "The POPI model is provided by the Léon Bérard Cancer Center & CREATIS Laboratory, Lyon, France. The relevant publication is:\n",
    "\n",
    "J. Vandemeulebroucke, D. Sarrut, P. Clarysse, \"The POPI-model, a point-validated pixel-based breathing thorax model\",\n",
    "Proc. XVth International Conference on the Use of Computers in Radiation Therapy (ICCR), Toronto, Canada, 2007.\n",
    "\n",
    "The POPI data, and additional 4D CT data sets with reference points are available from the CREATIS Laboratory <a href=\"http://www.creatis.insa-lyon.fr/rio/popi-model?action=show&redirect=popi\">here</a>. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "library(SimpleITK)\n",
    "\n",
    "# If the environment variable SIMPLE_ITK_MEMORY_CONSTRAINED_ENVIRONMENT is set, this will override the ReadImage\n",
    "# function so that it also resamples the image to a smaller size (testing environment is memory constrained).\n",
    "source(\"setup_for_testing.R\")\n",
    "\n",
    "library(ggplot2)\n",
    "library(tidyr)\n",
    "library(purrr)\n",
    "# Utility method that either downloads data from the Girder repository or\n",
    "# if already downloaded returns the file name for reading from disk (cached data).\n",
    "source(\"downloaddata.R\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utilities\n",
    "Utility methods used in the notebook for display and registration evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "source(\"registration_utilities.R\")"
   ]
  },
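  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The registration evaluation in this notebook is based on the target registration error (TRE): the distance between a fixed image landmark mapped by the transformation and its corresponding moving image landmark. The following is a minimal sketch of that computation in base R, not the implementation found in registration_utilities.R. The helper name `tre_sketch`, the landmark coordinates, and the representation of the transformation as a plain R function are assumptions made for illustration.\n",
    "\n",
    "```r\n",
    "# Hypothetical helper (not part of the notebook utilities): Euclidean distance between\n",
    "# each mapped fixed landmark and the corresponding moving landmark.\n",
    "tre_sketch <- function(transform_point, fixed_points, moving_points) {\n",
    "    # Apply the mapping to every row (landmark) of the fixed point matrix.\n",
    "    mapped_points <- t(apply(fixed_points, 1, transform_point))\n",
    "    sqrt(rowSums((mapped_points - moving_points)^2))\n",
    "}\n",
    "\n",
    "# Two made-up landmark pairs, one per row, coordinates in millimeters.\n",
    "fixed_points <- matrix(c(0, 0, 0,  10, 0, 0), ncol = 3, byrow = TRUE)\n",
    "moving_points <- matrix(c(0, 3, 4,  10, 0, 0), ncol = 3, byrow = TRUE)\n",
    "\n",
    "# The identity mapping gives the errors before any registration.\n",
    "identity_mapping <- function(p) p\n",
    "errors <- tre_sketch(identity_mapping, fixed_points, moving_points)\n",
    "cat(sprintf('mean TRE: %.2f mm, max TRE: %.2f mm\\n', mean(errors), max(errors)))\n",
    "```\n",
    "\n",
    "With the identity mapping the first landmark is 5 mm from its counterpart and the second coincides with it, so a successful registration should reduce the first error."
   ]
  },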
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Data\n",
    "\n",
    "Load all of the images, masks and point data into corresponding lists. If the data is not available locally it will be downloaded from the original remote repository. \n",
    "\n",
    "Take a look at a temporal slice for a specific coronal index (center of volume). According to the documentation on the POPI site, volume number one corresponds to end inspiration (maximal air volume).\n",
    "\n",
    "You can modify the coronal index to look at other temporal slices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "body_label <- 0\n",
    "air_label <- 1\n",
    "lung_label <- 2    \n",
    "\n",
    "image_file_names <- file.path(\"POPI\", \"meta\", paste0(0:9, \"0-P.mhd\"))\n",
    "# Read the CT images as 32bit float, the pixel type required for registration.\n",
    "image_list <- lapply(image_file_names, function(image_file_name) ReadImage(fetch_data(image_file_name), \"sitkFloat32\"))    \n",
    "\n",
    "mask_file_names <- file.path(\"POPI\", \"masks\", paste0(0:9, \"0-air-body-lungs.mhd\"))\n",
    "mask_list <- lapply(mask_file_names, function(mask_file_name) ReadImage(fetch_data(mask_file_name)))    \n",
    "\n",
    "points_file_names <- file.path(\"POPI\", \"landmarks\", paste0(0:9, \"0-Landmarks.pts\"))\n",
    "points_list <- lapply(points_file_names, function(points_file_name) read.table(fetch_data(points_file_name)))\n",
    "    \n",
    "# Look at a temporal slice for the specific coronal index     \n",
    "coronal_index <- as.integer(round(image_list[[1]]$GetHeight()/2.0))\n",
    "temporal_slice <- temporal_coronal_with_overlay(coronal_index, image_list, mask_list, lung_label, -1024, 976)\n",
    "    # Flip the image so that it corresponds to the standard radiological display.\n",
    "Show(temporal_slice[,seq(temporal_slice$GetHeight(),0,-1),])  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demons Registration\n",
    "\n",
    "This function will align the fixed and moving images using the Demons registration method. If given a mask, the similarity metric will be evaluated using points sampled inside the mask. If given fixed and moving points the similarity metric value and the target registration errors will be displayed during registration. \n",
    "\n",
    "As this notebook performs intra-modal registration, we can readily use the Demons family of algorithms.\n",
    "\n",
    "We start by using the registration framework with SetMetricAsDemons. We use a multiscale approach which is readily available in the framework. We then illustrate how to use the Demons registration filters that are not part of the registration framework."
   ]
  },
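  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, the classic Demons update attributed to Thirion computes, at each voxel, a displacement from the intensity difference and the fixed image gradient. The expression below is a sketch of that textbook form, under the assumption that $f$ is the fixed image, $m$ is the moving image warped by the current displacement field, and $\\mathbf{u}$ is the update to the field. Sign conventions vary between references, and the SimpleITK implementations differ in details such as normalization and thresholding.\n",
    "\n",
    "$$\n",
    "\\mathbf{u}(\\mathbf{x}) = \\frac{\\left(f(\\mathbf{x}) - m(\\mathbf{x})\\right)\\,\\nabla f(\\mathbf{x})}{\\lVert\\nabla f(\\mathbf{x})\\rVert^{2} + \\left(f(\\mathbf{x}) - m(\\mathbf{x})\\right)^{2}}\n",
    "$$\n",
    "\n",
    "The numerator goes to zero where the two intensities agree, so no update is applied there. This is why the approach relies on the same tissue having the same intensity in both images."
   ]
  },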
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "demons_registration <- function(fixed_image, moving_image)\n",
    "{    \n",
    "    registration_method <- ImageRegistrationMethod()\n",
    "\n",
    "    # Create initial identity transformation.\n",
    "    transform_to_displacment_field_filter <- TransformToDisplacementFieldFilter()\n",
    "    transform_to_displacment_field_filter$SetReferenceImage(fixed_image)\n",
    "    # The image returned from the initial_transform_filter is transferred to the transform and cleared out.\n",
    "    initial_transform <- DisplacementFieldTransform(transform_to_displacment_field_filter$Execute(Transform()))\n",
    "    \n",
    "    # Regularization (update field - viscous, total field - elastic).\n",
    "    initial_transform$SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0, varianceForTotalField=2.0) \n",
    "    \n",
    "    registration_method$SetInitialTransform(initial_transform)\n",
    "\n",
    "    registration_method$SetMetricAsDemons(10) #intensities are equal if the difference is less than 10HU\n",
    "        \n",
    "    # Multi-resolution framework.            \n",
    "    registration_method$SetShrinkFactorsPerLevel(shrinkFactors = c(4,2,1))\n",
    "    registration_method$SetSmoothingSigmasPerLevel(smoothingSigmas = c(8,4,0))    \n",
    "\n",
    "    registration_method$SetInterpolator(\"sitkLinear\")\n",
    "    # If you have time, run this code using the ConjugateGradientLineSearch, otherwise run as is.   \n",
    "    #registration_method$SetOptimizerAsConjugateGradientLineSearch(learningRate=1.0, numberOfIterations=20, convergenceMinimumValue=1e-6, convergenceWindowSize=10)\n",
    "    registration_method$SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=20, convergenceMinimumValue=1e-6, convergenceWindowSize=10)\n",
    "    registration_method$SetOptimizerScalesFromPhysicalShift()\n",
    "        \n",
    "    return (registration_method$Execute(fixed_image, moving_image))\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running the Demons registration with the conjugate gradient optimizer on this data <font color=\"red\">takes a long time</font> which is why the code above uses gradient descent. If you are more interested in accuracy and have the time then switch to the conjugate gradient optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Select the fixed and moving images, valid entries are in [1,10]\n",
    "fixed_image_index <- 1\n",
    "moving_image_index <- 8\n",
    "\n",
    "\n",
    "tx <- demons_registration(fixed_image = image_list[[fixed_image_index]], \n",
    "                          moving_image = image_list[[moving_image_index]])\n",
    "\n",
    "initial_errors <- registration_errors(Euler3DTransform(), points_list[[fixed_image_index]], points_list[[moving_image_index]])\n",
    "final_errors <- registration_errors(tx, points_list[[fixed_image_index]], points_list[[moving_image_index]])\n",
    "\n",
    "# Plot the TRE histograms before and after registration.\n",
    "df <- data.frame(AfterRegistration=final_errors, BeforeRegistration=initial_errors)\n",
    "df.long <- gather(df, key=ErrorType, value=ErrorMagnitude)\n",
    "\n",
    "ggplot(df.long, aes(x=ErrorMagnitude, group=ErrorType, colour=ErrorType, fill=ErrorType)) + \n",
    "geom_histogram(bins=20,position='identity', alpha=0.3) + \n",
    "theme(legend.title=element_blank(), legend.position=c(.85, .85))\n",
    "## Or, if you prefer density plots\n",
    "ggplot(df.long, aes(x=ErrorMagnitude, group=ErrorType, colour=ErrorType, fill=ErrorType)) + \n",
    "geom_density(position='identity', alpha=0.3) + \n",
    "theme(legend.title=element_blank(), legend.position=c(.85, .85))\n",
    "\n",
    "\n",
    "cat(paste0('Initial alignment errors in millimeters, mean(std): ',\n",
    "           sprintf('%.2f',mean(initial_errors)),'(',sprintf('%.2f',sd(initial_errors)),') max:', sprintf('%.2f\\n',max(initial_errors))))\n",
    "cat(paste0('Final alignment errors in millimeters, mean(std): ',\n",
    "           sprintf('%.2f',mean(final_errors)),'(',sprintf('%.2f',sd(final_errors)),') max:', sprintf('%.2f\\n',max(final_errors))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "# Transfer the segmentation via the estimated transformation. Use Nearest Neighbor interpolation to retain the labels.\n",
    "transformed_labels <- Resample(mask_list[[moving_image_index]],\n",
    "                               image_list[[fixed_image_index]],\n",
    "                               tx, \n",
    "                               \"sitkNearestNeighbor\",\n",
    "                               0.0, \n",
    "                               mask_list[[moving_image_index]]$GetPixelID())\n",
    "\n",
    "segmentations_before_and_after <- c(mask_list[[moving_image_index]], transformed_labels)\n",
    "\n",
    "# Look at the segmentation overlay before and after registration for a specific coronal slice\n",
    "coronal_index_registration_evaluation <- as.integer(round(image_list[[fixed_image_index]]$GetHeight()/2.0))\n",
    "temporal_slice <- temporal_coronal_with_overlay(coronal_index_registration_evaluation, \n",
    "                                                list(image_list[[fixed_image_index]], image_list[[fixed_image_index]]), \n",
    "                                                segmentations_before_and_after,\n",
    "                                                lung_label, -1024, 976)\n",
    "    # Flip the image so that it corresponds to the standard radiological display.\n",
    "Show(temporal_slice[,seq(temporal_slice$GetHeight(),0,-1),])  "
   ]
  },
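  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason for using nearest neighbor interpolation when resampling the masks can be seen in one dimension. With the label values used in this notebook (body 0, air 1, lung 2), linear interpolation across a body/lung boundary produces intermediate values, including 1, which is the air label. The following is a minimal sketch in base R with made-up sample positions; it does not call SimpleITK.\n",
    "\n",
    "```r\n",
    "body_label <- 0\n",
    "air_label <- 1\n",
    "lung_label <- 2\n",
    "\n",
    "# Two neighboring voxels along a line: body at position 0, lung at position 1.\n",
    "voxel_labels <- c(body_label, lung_label)\n",
    "# Hypothetical positions at which the resampled image needs a value.\n",
    "sample_positions <- c(0.25, 0.5, 0.75)\n",
    "\n",
    "# Linear interpolation blends the two label values.\n",
    "linear_values <- (1 - sample_positions) * voxel_labels[1] + sample_positions * voxel_labels[2]\n",
    "# Nearest neighbor picks the label of the closer voxel (ties go to the second voxel here).\n",
    "nearest_values <- voxel_labels[ifelse(sample_positions < 0.5, 1, 2)]\n",
    "\n",
    "print(linear_values)   # includes values that are not valid body/lung labels\n",
    "print(nearest_values)  # only contains the original label values\n",
    "```\n",
    "\n",
    "The midpoint value from linear interpolation equals the air label, even though no air voxel is involved."
   ]
  },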
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SimpleITK also includes a set of Demons filters which are independent of the ImageRegistrationMethod. These include: \n",
    "1. DemonsRegistrationFilter\n",
    "2. DiffeomorphicDemonsRegistrationFilter\n",
    "3. FastSymmetricForcesDemonsRegistrationFilter\n",
    "4. SymmetricForcesDemonsRegistrationFilter\n",
    "\n",
    "As these filters are independent of the ImageRegistrationMethod we do not have access to the multiscale framework. Luckily it is easy to implement our own multiscale framework in SimpleITK, which is what we do in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#    \n",
    "# Args:\n",
    "#        image: The image we want to resample.\n",
    "#        shrink_factor: A number greater than one, such that the new image's size is original_size/shrink_factor.\n",
    "#        smoothing_sigma: Sigma for Gaussian smoothing, this is in physical (image spacing) units, not pixels.\n",
    "#    Return:\n",
    "#        Image which is a result of smoothing the input and then resampling it using the given sigma and shrink factor.\n",
    "#\n",
    "smooth_and_resample <- function(image, shrink_factor, smoothing_sigma)\n",
    "{\n",
    "    smoothed_image <- SmoothingRecursiveGaussian(image, smoothing_sigma)\n",
    "    \n",
    "    original_spacing <- image$GetSpacing()\n",
    "    original_size <- image$GetSize()\n",
    "    new_size <-  as.integer(round(original_size/shrink_factor))\n",
    "    new_spacing <- (original_size-1)*original_spacing/(new_size-1)\n",
    "\n",
    "    return(Resample(smoothed_image, new_size, Transform(), \n",
    "                    \"sitkLinear\", image$GetOrigin(),\n",
    "                    new_spacing, image$GetDirection(), 0.0, \n",
    "                    image$GetPixelID()))\n",
    "}\n",
    "\n",
    "#    \n",
    "# Run the given registration algorithm in a multiscale fashion. The original scale should not be given as input as the\n",
    "# original images are implicitly incorporated as the base of the pyramid.\n",
    "# Args:\n",
    "#   registration_algorithm: Any registration algorithm that has an Execute(fixed_image, moving_image, displacement_field_image)\n",
    "#                           method.\n",
    "#   fixed_image: Resulting transformation maps points from this image's spatial domain to the moving image spatial domain.\n",
    "#   moving_image: Resulting transformation maps points from the fixed_image's spatial domain to this image's spatial domain.\n",
    "#   initial_transform: Any SimpleITK transform, used to initialize the displacement field.\n",
    "#   shrink_factors: Shrink factors relative to the original image's size.\n",
    "#   smoothing_sigmas: Amount of smoothing which is done prior to resampling the image using the given shrink factor. These\n",
    "#                     are in physical (image spacing) units.\n",
    "# Returns: \n",
    "#    DisplacementFieldTransform\n",
    "#\n",
    "multiscale_demons <- function(registration_algorithm, fixed_image, moving_image, initial_transform = NULL, \n",
    "                              shrink_factors=NULL, smoothing_sigmas=NULL)\n",
    "{    \n",
    "    # Create image pyramids. \n",
    "    fixed_images <- c(fixed_image, \n",
    "                      if(!is.null(shrink_factors))\n",
    "                          map2(rev(shrink_factors), rev(smoothing_sigmas), \n",
    "                               ~smooth_and_resample(fixed_image, .x, .y))\n",
    "                      )\n",
    "    moving_images <- c(moving_image, \n",
    "                       if(!is.null(shrink_factors))\n",
    "                           map2(rev(shrink_factors), rev(smoothing_sigmas), \n",
    "                               ~smooth_and_resample(moving_image, .x, .y))\n",
    "                       )\n",
    "\n",
    "    # Uncomment the following two lines if you want to see your image pyramids.\n",
    "    #lapply(fixed_images, Show)\n",
    "    #lapply(moving_images, Show)\n",
    "    \n",
    "                              \n",
    "    # Create initial displacement field at lowest resolution. \n",
    "    # Currently, the pixel type is required to be sitkVectorFloat64 because of a constraint imposed by the Demons filters.\n",
    "    lastImage <- fixed_images[[length(fixed_images)]]\n",
    "    if(!is.null(initial_transform))\n",
    "    {\n",
    "        initial_displacement_field = TransformToDisplacementField(initial_transform, \n",
    "                                                                  \"sitkVectorFloat64\",\n",
    "                                                                  lastImage$GetSize(),\n",
    "                                                                  lastImage$GetOrigin(),\n",
    "                                                                  lastImage$GetSpacing(),\n",
    "                                                                  lastImage$GetDirection())\n",
    "    }\n",
    "    else\n",
    "    {\n",
    "        initial_displacement_field <- Image(lastImage$GetWidth(), \n",
    "                                            lastImage$GetHeight(),\n",
    "                                            lastImage$GetDepth(),\n",
    "                                            \"sitkVectorFloat64\")\n",
    "        initial_displacement_field$CopyInformation(lastImage)\n",
    "    }\n",
    "    # Run the registration pyramid, run a registration at the top of the pyramid and then iterate: \n",
    "    # a. resampling previous deformation field onto higher resolution grid.\n",
    "    # b. register.\n",
    "    initial_displacement_field <- registration_algorithm$Execute(fixed_images[[length(fixed_images)]], \n",
    "                                                                moving_images[[length(moving_images)]], \n",
    "                                                                initial_displacement_field)\n",
    "    # This is a use case for a loop, because the operations depend on the previous step. Otherwise\n",
    "    # we need to mess around with tricky assignments to variables in different scopes\n",
    "    for (idx in seq(length(fixed_images)-1,1)) {\n",
    "        f_image <- fixed_images[[idx]]\n",
    "        m_image <- moving_images[[idx]]\n",
    "        initial_displacement_field <- Resample(initial_displacement_field, f_image)\n",
    "        initial_displacement_field <- registration_algorithm$Execute(f_image, m_image, initial_displacement_field)\n",
    "    }\n",
    "\n",
    "    return(DisplacementFieldTransform(initial_displacement_field))\n",
    "}"
   ]
  },
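  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The size and spacing arithmetic in smooth_and_resample keeps the physical extent of the image unchanged, measured between the centers of the first and last voxels along each axis. The following is a minimal sketch of that arithmetic in base R. The size and spacing values are made up for illustration and are not the POPI image geometry.\n",
    "\n",
    "```r\n",
    "# Hypothetical image geometry, used only for illustration.\n",
    "original_size <- c(101, 101, 41)\n",
    "original_spacing <- c(1.0, 1.0, 2.0)\n",
    "shrink_factor <- 4\n",
    "\n",
    "# Same computation as in smooth_and_resample.\n",
    "new_size <- as.integer(round(original_size / shrink_factor))\n",
    "new_spacing <- (original_size - 1) * original_spacing / (new_size - 1)\n",
    "\n",
    "print(new_size)\n",
    "print(new_spacing)\n",
    "# Extent between the first and last voxel centers, before and after shrinking.\n",
    "print((original_size - 1) * original_spacing)\n",
    "print((new_size - 1) * new_spacing)\n",
    "```\n",
    "\n",
    "Because the size is rounded to an integer, the new spacing is generally not exactly shrink_factor times the original spacing; it is chosen so that the two extents printed at the end agree."
   ]
  },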
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will use our newly minted multiscale framework to perform registration with the Demons filters. Some things you can easily try out by editing the code below:\n",
    "1. Is there really a need for multiscale - just call the multiscale_demons method without the shrink_factors and smoothing_sigmas parameters.\n",
    "2. Which Demons filter should you use - configure the other filters and see if our selection is the best choice (accuracy/time)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fixed_image_index <- 1\n",
    "moving_image_index <- 8\n",
    "\n",
    "# Select a Demons filter and configure it.\n",
    "demons_filter <-  FastSymmetricForcesDemonsRegistrationFilter()\n",
    "demons_filter$SetNumberOfIterations(20)\n",
    "# Regularization (update field - viscous, total field - elastic).\n",
    "demons_filter$SetSmoothDisplacementField(TRUE)\n",
    "demons_filter$SetStandardDeviations(2.0)\n",
    "\n",
    "# Run the registration.\n",
    "tx <- multiscale_demons(registration_algorithm=demons_filter, \n",
    "                       fixed_image = image_list[[fixed_image_index]], \n",
    "                       moving_image = image_list[[moving_image_index]],\n",
    "                       shrink_factors = c(4,2),\n",
    "                       smoothing_sigmas = c(8,4))\n",
    "\n",
    "# Compare the initial and final TREs.\n",
    "initial_errors <- registration_errors(Euler3DTransform(), points_list[[fixed_image_index]], points_list[[moving_image_index]])\n",
    "final_errors <- registration_errors(tx, points_list[[fixed_image_index]], points_list[[moving_image_index]])\n",
    "\n",
    "# Plot the TRE histograms before and after registration.\n",
    "cat(paste0('Initial alignment errors in millimeters, mean(std): ',\n",
    "           sprintf('%.2f',mean(initial_errors)),'(',sprintf('%.2f',sd(initial_errors)),') max:', sprintf('%.2f\\n',max(initial_errors))))\n",
    "cat(paste0('Final alignment errors in millimeters, mean(std): ',\n",
    "           sprintf('%.2f',mean(final_errors)),'(',sprintf('%.2f',sd(final_errors)),') max:', sprintf('%.2f\\n',max(final_errors))))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.2.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
